feat: add serverless memory API functions and memory system package #800

Open

inoribea wants to merge 3 commits into moeru-ai:main from inoribea:pr-memory-api

Conversation

@inoribea

Summary

This PR adds serverless memory API functions and a memory system package to provide both short-term and long-term memory capabilities for AIRI.

Details

  • Implements serverless functions for memory operations (save, search, clear)
  • Provides support for multiple memory providers (Vercel KV, Upstash Redis, PostgreSQL with pgvector, Qdrant)
  • Adds memory embedding functionality with support for OpenAI and Cloudflare embeddings
  • Creates a unified memory system package with provider abstraction
  • Includes migration scripts for database setup

Testing

  1. Verify that memory save API endpoint works correctly
  2. Test memory search functionality with different providers
  3. Confirm short-term memory operations (Vercel KV, Upstash Redis)
  4. Verify long-term memory operations (PostgreSQL, Qdrant)
  5. Test embedding functionality with various providers
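
For reference, a minimal smoke test against a deployment could look like the sketch below. The endpoint paths and payload fields are inferred from the handlers in this PR (sessionId/message/userId for save; query/userId/limit for search), and the base URL is a placeholder:

// Hypothetical smoke test; adjust the base URL and fields to your deployment.
const base = 'https://your-deployment.vercel.app'

// 1. Save a message to short-term memory
await fetch(`${base}/api/memory/save`, {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({
    sessionId: 'demo-session',
    userId: 'demo-user',
    message: { role: 'user', content: 'Hello, AIRI!' },
  }),
})

// 2. Search long-term memory semantically
await fetch(`${base}/api/memory/search`, {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({ query: 'What did we talk about?', userId: 'demo-user', limit: 5 }),
})

// 3. Clear the session's short-term memory
await fetch(`${base}/api/memory/clear`, {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({ sessionId: 'demo-session' }),
})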

@gemini-code-assist
Contributor

Summary of Changes

Hello @inoribea, I'm Gemini Code Assist[1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances AIRI's conversational capabilities by implementing a robust and flexible memory system. It provides distinct mechanisms for short-term and long-term memory, allowing for efficient storage and retrieval of conversational context. The system is designed with extensibility in mind, supporting multiple backend providers for both memory types and integrating with various embedding services for advanced semantic search, all exposed through a set of serverless API functions.

Highlights

  • Serverless Memory API: Introduces dedicated serverless functions for managing AI memory, including saving, searching, and clearing conversational context.
  • Multi-Provider Support: Offers flexible short-term memory options (Vercel KV, Upstash Redis) and long-term memory solutions (PostgreSQL with pgvector, Qdrant).
  • Advanced Embedding Integration: Supports OpenAI and Cloudflare for generating embeddings, which are crucial for semantic search capabilities in long-term memory.
  • Unified Memory System Package: A new @proj-airi/memory-system package provides an abstracted and unified interface for memory management, simplifying integration and extensibility.
  • Database Migration Scripts: Includes SQL migration scripts for setting up PostgreSQL tables with pgvector for efficient vector storage and search, along with detailed setup instructions.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature              | Command             | Description
Code Review          | /gemini review      | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary     | Provides a summary of the current pull request in its current state.
Comment              | @gemini-code-assist | Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help                 | /gemini help        | Displays a list of available commands.

Customization

To customize your Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces a comprehensive memory system for AIRI, with both standalone serverless functions and a more robust, reusable package. It supports multiple providers for short-term (Vercel KV, Upstash Redis) and long-term (Postgres, Qdrant) memory, which is great for flexibility.

My review focuses on improving robustness, security, and maintainability. Key points include:

  • Addressing a critical race condition in the short-term memory implementation in api/_lib/memory.ts.
  • Fixing a critical bug in the new packages/memory factory that prevents the Qdrant provider from being used.
  • Enhancing input validation in the API handlers to prevent runtime errors from malformed requests.
  • Suggestions for refactoring duplicated code and improving error handling.

Overall, this is a significant and well-structured feature addition. Addressing these points will make the memory system more reliable and easier to maintain.

Comment on lines +293 to +295
if (config.provider !== 'postgres-pgvector') {
  throw new Error(`Unsupported long-term memory provider: ${config.provider}`)
}
Contributor

critical

There is a logic error in createLongTermFromConfig that prevents the qdrant provider from being configured. The code throws an error if the provider is not postgres-pgvector, making the subsequent logic block for qdrant unreachable. This check should be removed to allow for multiple long-term providers.
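
A minimal sketch of the corrected branching; the type and constructor names below are placeholders for the ones actually defined in packages/memory:

function createLongTermFromConfig(config: LongTermMemoryConfiguration): ILongTermMemoryProvider {
  // Each supported provider gets its own branch; only throw once all
  // known providers have been ruled out.
  if (config.provider === 'postgres-pgvector')
    return new PostgresPgvectorProvider(config)

  if (config.provider === 'qdrant')
    return new QdrantProvider(config)

  throw new Error(`Unsupported long-term memory provider: ${config.provider}`)
}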

Comment on lines +541 to +551
const redis = getRedisClient()
const existing = await redis.get<Message[]>(key) || []
const updated = [...existing, normalizedMessage]

// Keep only last N messages
const maxMessages = config.shortTerm.maxMessages || 20
const trimmed = updated.slice(-maxMessages)

// Save with TTL
const ttl = config.shortTerm.ttlSeconds || 1800
await redis.setex(key, ttl, JSON.stringify(trimmed))
Contributor

high

The current implementation for saving messages to Upstash Redis uses a non-atomic read-modify-write pattern (get, modify array, setex). This can lead to a race condition and message loss if multiple requests to save a message for the same session arrive concurrently.

A more robust approach is to use Redis's atomic list operations. I suggest using a pipeline with rpush to add the message and ltrim to keep the list at the desired size. This change will also require updating getRecentMessages to use lrange.

Suggested change

Before:

const redis = getRedisClient()
const existing = await redis.get<Message[]>(key) || []
const updated = [...existing, normalizedMessage]

// Keep only last N messages
const maxMessages = config.shortTerm.maxMessages || 20
const trimmed = updated.slice(-maxMessages)

// Save with TTL
const ttl = config.shortTerm.ttlSeconds || 1800
await redis.setex(key, ttl, JSON.stringify(trimmed))

After:

const redis = getRedisClient()
const maxMessages = config.shortTerm.maxMessages || 20
const ttl = config.shortTerm.ttlSeconds || 1800
const pipeline = redis.pipeline()
pipeline.rpush(key, JSON.stringify(normalizedMessage))
pipeline.ltrim(key, -maxMessages, -1)
if (ttl > 0) {
  pipeline.expire(key, ttl)
}
await pipeline.exec()

Comment on lines +571 to +575
const redis = getRedisClient()
const data = await redis.get<string>(key)
if (data) {
  messages = typeof data === 'string' ? JSON.parse(data) : data
}
Contributor

high

To complement the change to atomic list operations in saveMessage, getRecentMessages for Upstash Redis should be updated to use lrange to fetch the list of messages and then parse each item.

Suggested change

Before:

const redis = getRedisClient()
const data = await redis.get<string>(key)
if (data) {
  messages = typeof data === 'string' ? JSON.parse(data) : data
}

After:

const redis = getRedisClient()
const data = await redis.lrange(key, 0, -1)
if (data && data.length > 0) {
  messages = data.map(item => JSON.parse(item as string))
}

Comment on lines +36 to +40
if (!body?.message || typeof body.message !== 'object') {
  return res.status(400).json({ success: false, error: 'message payload is required' })
}

await saveMessage(body.sessionId, body.message as any, body.userId)
Contributor

high

The validation for the message object is minimal, and it's then cast to any, which is unsafe. This can lead to runtime errors if an invalid message object is sent. Please add more specific validation to ensure the message object conforms to the Message interface before passing it to saveMessage.

    if (
      !body?.message
      || typeof body.message !== 'object'
      || !('role' in body.message)
      || typeof (body.message as any).role !== 'string'
      || !('content' in body.message)
    ) {
      return res.status(400).json({ success: false, error: 'message payload is required and must be an object with role and content' })
    }

    await saveMessage(body.sessionId, body.message as any, body.userId)

Comment on lines +21 to +29
if (req.method === 'POST') {
  const body = req.body

  if (!body) {
    return res.status(400).json({ success: false, error: 'Configuration payload is required' })
  }

  setConfiguration(body)
  return res.status(200).json({ success: true })
Contributor

high

The POST handler accepts any request body and passes it to setConfiguration without validation. An invalid configuration object could break the memory system for this serverless function instance until it's restarted. You should add validation to ensure the request body conforms to the MemoryConfiguration interface before setting it. Using a validation library like zod would be a robust way to handle this.
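
For illustration, a rough zod sketch; the schema fields and provider names below are assumptions based on the options mentioned in this PR and would need to match the real MemoryConfiguration interface:

import { z } from 'zod'

// Assumed shape; align fields and provider names with the real MemoryConfiguration.
const memoryConfigurationSchema = z.object({
  shortTerm: z.object({
    provider: z.enum(['vercel-kv', 'upstash-redis']),
    maxMessages: z.number().int().positive().optional(),
    ttlSeconds: z.number().int().positive().optional(),
  }).optional(),
  longTerm: z.object({
    provider: z.enum(['postgres-pgvector', 'qdrant']),
  }).passthrough().optional(),
  embedding: z.object({
    provider: z.enum(['openai', 'openai-compatible', 'cloudflare']),
    apiKey: z.string().optional(),
    model: z.string().optional(),
    baseUrl: z.string().optional(),
    accountId: z.string().optional(),
  }).optional(),
})

// In the POST handler:
const parsed = memoryConfigurationSchema.safeParse(req.body)
if (!parsed.success) {
  return res.status(400).json({ success: false, error: parsed.error.message })
}
setConfiguration(parsed.data)
return res.status(200).json({ success: true })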

openaiClient = null

if (externalPostgresPool) {
  externalPostgresPool.end().catch(() => {})
Contributor

medium

The error from externalPostgresPool.end() is being ignored. While errors during pool termination might not always be critical, it's good practice to at least log them for debugging purposes. This can help identify underlying issues with the database connection or environment.

Suggested change

Before:

externalPostgresPool.end().catch(() => {})

After:

externalPostgresPool.end().catch(err => console.error('Failed to close external Postgres pool:', err))


const openaiConfig = embedding.openai ?? {
  apiKey: embedding.apiKey ?? '',
  baseURL: embedding.baseUrl ?? (embedding as Record<string, unknown>).baseURL as string | undefined,
Contributor

medium

The type cast (embedding as Record<string, unknown>).baseURL as string | undefined is unnecessary and makes the code harder to read. The MemoryConfiguration interface already defines baseUrl as an optional string property on embedding. You can access it directly with embedding.baseUrl.

Suggested change

Before:

baseURL: embedding.baseUrl ?? (embedding as Record<string, unknown>).baseURL as string | undefined,

After:

baseURL: embedding.baseUrl,

if (provider === 'openai' || provider === 'openai-compatible') {
  const openai = getOpenAIClient()
  const response = await openai.embeddings.create({
    model: model || 'text-embedding-3-small',
Contributor

medium

The default embedding model 'text-embedding-3-small' is hardcoded here. To improve maintainability and avoid magic strings, it's better to define this as a constant at the top of the file.

Example:

const DEFAULT_OPENAI_EMBEDDING_MODEL = 'text-embedding-3-small';

// ... later in generateEmbedding
model: model || DEFAULT_OPENAI_EMBEDDING_MODEL,

Comment on lines 205 to 348
  private async generateEmbedding(input: string): Promise<number[]> {
    const trimmed = input.trim()
    if (trimmed.length === 0) {
      return Array.from({ length: this.vectorSize }, () => 0)
    }

    let embedding: number[] = []

    if (this.embeddingConfig.provider === 'cloudflare') {
      embedding = await this.generateCloudflareEmbedding(trimmed)
    }
    else {
      if (!this.openai) {
        throw new Error('OpenAI-compatible client is not configured for embeddings.')
      }

      const response = await this.openai.embeddings.create({
        model: this.embeddingConfig.model ?? DEFAULT_EMBEDDING_MODEL,
        input: trimmed,
      })

      embedding = response.data?.[0]?.embedding ?? []
    }

    if (!embedding.length) {
      throw new Error('Failed to generate embedding for memory content.')
    }

    if (embedding.length !== this.vectorSize) {
      throw new Error(`Embedding dimension mismatch. Expected ${this.vectorSize}, received ${embedding.length}.`)
    }

    return embedding.map((value: number | string) => Number(value))
  }

  private extractUserId(message: Message): string | null {
    const metadata = message.metadata ?? {}
    const candidate = (metadata as Record<string, unknown>).userId
      ?? (metadata as Record<string, unknown>).userID
      ?? (metadata as Record<string, unknown>).user_id

    return typeof candidate === 'string' && candidate.length > 0 ? candidate : null
  }

  private extractSessionId(message: Message): string | null {
    const metadata = message.metadata ?? {}
    const candidate = (metadata as Record<string, unknown>).sessionId
      ?? (metadata as Record<string, unknown>).sessionID
      ?? (metadata as Record<string, unknown>).session_id

    return typeof candidate === 'string' && candidate.length > 0 ? candidate : null
  }

  private resolveEmbeddingConfiguration(
    config?: EmbeddingProviderConfiguration,
  ): EmbeddingProviderConfiguration {
    const provider = config?.provider
      ?? (env.MEMORY_EMBEDDING_PROVIDER as EmbeddingProviderConfiguration['provider'] | undefined)
      ?? 'openai'

    const apiKey = config?.apiKey
      ?? env.MEMORY_EMBEDDING_API_KEY
      ?? env.OPENAI_API_KEY

    if (!apiKey) {
      throw new Error('An embedding API key is required.')
    }

    const model = config?.model
      ?? env.MEMORY_EMBEDDING_MODEL
      ?? DEFAULT_EMBEDDING_MODEL

    const baseUrl = config?.baseUrl ?? env.MEMORY_EMBEDDING_BASE_URL
    const accountId = config?.accountId ?? env.CLOUDFLARE_ACCOUNT_ID

    if (provider === 'cloudflare' && !accountId) {
      throw new Error('Cloudflare embedding provider requires an account ID.')
    }

    return {
      provider,
      apiKey,
      model,
      baseUrl,
      accountId,
    } satisfies EmbeddingProviderConfiguration
  }

  private async generateCloudflareEmbedding(input: string): Promise<number[]> {
    const accountId = this.embeddingConfig.accountId
    if (!accountId) {
      throw new Error('Cloudflare account ID is not configured.')
    }

    const url = `${this.embeddingConfig.baseUrl ?? 'https://api.cloudflare.com/client/v4'}/accounts/${accountId}/ai/run/${this.embeddingConfig.model ?? DEFAULT_EMBEDDING_MODEL}`

    const response = await fetch(url, {
      method: 'POST',
      headers: {
        'Authorization': `Bearer ${this.embeddingConfig.apiKey}`,
        'Content-Type': 'application/json',
      },
      body: JSON.stringify({ text: input }),
    })

    if (!response.ok) {
      const body = await response.text()
      throw new Error(`Cloudflare embedding request failed: ${response.status} ${body}`)
    }

    const payload = await response.json() as Record<string, any>
    const embedding = this.extractCloudflareEmbedding(payload)

    if (!embedding || !Array.isArray(embedding) || embedding.length === 0) {
      throw new Error('Cloudflare embedding response did not include an embedding vector.')
    }

    return embedding.map((value: number | string) => Number(value))
  }

  private extractCloudflareEmbedding(payload: Record<string, any>): number[] | undefined {
    const result = payload.result ?? payload

    if (Array.isArray(result?.data) && result.data.length > 0) {
      const candidate = result.data[0]
      if (Array.isArray(candidate?.embedding)) {
        return candidate.embedding as number[]
      }
      if (Array.isArray(candidate?.vector)) {
        return candidate.vector as number[]
      }
    }

    if (Array.isArray(result?.embedding)) {
      return result.embedding as number[]
    }

    if (Array.isArray(result?.data?.embedding)) {
      return result.data.embedding as number[]
    }

    return undefined
  }
}
Contributor

medium

There is significant code duplication between qdrant.provider.ts and postgres-pgvector.provider.ts for embedding generation logic (e.g., generateEmbedding, resolveEmbeddingConfiguration, generateCloudflareEmbedding, extractCloudflareEmbedding).

To improve maintainability, I recommend refactoring this shared logic into a separate EmbeddingGenerator class or a set of utility functions within the package (e.g., in a src/utils/embedding.ts file). Both long-term providers could then instantiate or import and use this shared component.
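
A rough sketch of the shared component; the class name and file path follow the suggestion above, and the method bodies would be moved out of the two providers rather than rewritten:

// src/utils/embedding.ts (sketch)
export class EmbeddingGenerator {
  constructor(
    private readonly config: EmbeddingProviderConfiguration,
    private readonly vectorSize: number,
  ) {}

  async generate(input: string): Promise<number[]> {
    const trimmed = input.trim()
    if (trimmed.length === 0)
      return Array.from({ length: this.vectorSize }, () => 0)

    // Same provider dispatch the two providers duplicate today.
    const embedding = this.config.provider === 'cloudflare'
      ? await this.generateCloudflareEmbedding(trimmed)
      : await this.generateOpenAIEmbedding(trimmed)

    if (embedding.length !== this.vectorSize)
      throw new Error(`Embedding dimension mismatch. Expected ${this.vectorSize}, received ${embedding.length}.`)

    return embedding.map(Number)
  }

  // generateOpenAIEmbedding / generateCloudflareEmbedding would be moved
  // here verbatim from the existing providers.
}

Both long-term providers would then hold an EmbeddingGenerator instance instead of their own copies.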

@shinohara-rin
Contributor

As a heads-up, there are plans to implement the memory system as plugins; please refer to #255 and #520.

@github-actions
Contributor

⏳ Approval required for deploying to Cloudflare Workers (Preview) for stage-web.

🔭 Waiting for approval: for maintainers, approve here.

Hey, @nekomeowww, @sumimakito, @luoling8192, @LemonNekoGH, kindly take some time to review and approve this deployment when you are available. Thank you! 🙏

Member

I think it'd be better if we could provide such an API both in Electron and as a remote API?

Author

I've switched to xsai, which can be used directly in Electron. The api/ directory is designed as a supplementary solution for remote API scenarios. Users can choose either the local or the remote approach based on their needs.


import { randomUUID } from 'node:crypto'

import OpenAI from 'openai'
Member

Should switch to xsai.js.org

Comment on lines 11 to 15
import { QdrantClient } from '@qdrant/js-client-rest'
import { Redis } from '@upstash/redis'
import { kv as vercelKv } from '@vercel/kv'
import { sql } from '@vercel/postgres'
import { Pool } from 'pg'
Member

I think we should have something like storage backends or abstractions for these, to support different storages. Does LlamaIndex provide enough design reference for this? At least they are a company, and hopefully their package will last long enough.

Author

We already have IMemoryProvider as the storage abstraction layer. I'm more familiar with LangChain; do you have specific design suggestions?
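
Roughly, the abstraction looks like this (the method names here are illustrative; the actual IMemoryProvider in packages/memory is authoritative):

interface IMemoryProvider {
  initialize: () => Promise<void>
  saveMessage: (sessionId: string, message: Message, userId?: string) => Promise<void>
  getRecentMessages: (sessionId: string, limit?: number) => Promise<Message[]>
  search: (query: string, userId: string, limit?: number) => Promise<Message[]>
  clearSession: (sessionId: string) => Promise<void>
}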

Comment on lines +8 to +19
CREATE TABLE IF NOT EXISTS memory_embeddings (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
content TEXT NOT NULL,
embedding vector(768), -- Adjust dimension based on your embedding model
metadata JSONB,
created_at TIMESTAMPTZ DEFAULT NOW(),

-- Indexes for performance
INDEX idx_memory_user_id (user_id),
INDEX idx_memory_created_at (created_at DESC)
);
Member

The design here isn't optimal, but we can go with this for now.
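
One concrete issue worth noting for that later pass: inline INDEX clauses inside CREATE TABLE are MySQL syntax and will fail on PostgreSQL. A sketch of the portable form, with the indexes created separately:

-- PostgreSQL does not accept INDEX clauses inside CREATE TABLE;
-- create the indexes as separate statements instead.
CREATE TABLE IF NOT EXISTS memory_embeddings (
  id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
  user_id TEXT NOT NULL,
  content TEXT NOT NULL,
  embedding vector(768), -- adjust dimension to the embedding model
  metadata JSONB,
  created_at TIMESTAMPTZ DEFAULT NOW()
);

CREATE INDEX IF NOT EXISTS idx_memory_user_id ON memory_embeddings (user_id);
CREATE INDEX IF NOT EXISTS idx_memory_created_at ON memory_embeddings (created_at DESC);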

Member

Remove this file.

Comment on lines +1 to +37
import type { VercelRequest, VercelResponse } from '@vercel/node'

import { clearSession } from '../_lib/memory'

export default async function handler(req: VercelRequest, res: VercelResponse) {
  // Set CORS headers
  res.setHeader('Access-Control-Allow-Origin', '*')
  res.setHeader('Access-Control-Allow-Methods', 'POST, OPTIONS')
  res.setHeader('Access-Control-Allow-Headers', 'Content-Type')

  if (req.method === 'OPTIONS') {
    return res.status(200).end()
  }

  if (req.method !== 'POST') {
    return res.status(405).json({ success: false, error: 'Method not allowed' })
  }

  try {
    const body = req.body as { sessionId?: string }

    if (!body?.sessionId || typeof body.sessionId !== 'string') {
      return res.status(400).json({ success: false, error: 'sessionId is required' })
    }

    await clearSession(body.sessionId)

    return res.status(200).json({ success: true })
  }
  catch (error) {
    console.error('Error in /api/memory/clear:', error)
    return res.status(500).json({
      success: false,
      error: error instanceof Error ? error.message : String(error),
    })
  }
}
Member

If so, why don't we use Nitro or Hono?

Author

The current code has too much CORS boilerplate. I'll migrate to Hono in a follow-up PR; it's a small change since the business logic is already separated in _lib/memory.ts.

Comment on lines +6 to +13
// Set CORS headers
res.setHeader('Access-Control-Allow-Origin', '*')
res.setHeader('Access-Control-Allow-Methods', 'GET, POST, OPTIONS')
res.setHeader('Access-Control-Allow-Headers', 'Content-Type')

if (req.method === 'OPTIONS') {
  return res.status(200).end()
}
Member

It seems CORS handling is repeated across multiple files; middleware should be used.
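
Until that migration lands, even a small shared helper would cut the repetition; a sketch under the current Vercel handler signatures (the helper name and location are hypothetical):

import type { VercelRequest, VercelResponse } from '@vercel/node'

// e.g. api/_lib/cors.ts; returns true when the request was an OPTIONS
// preflight and has already been answered.
export function applyCors(req: VercelRequest, res: VercelResponse, methods = 'GET, POST, OPTIONS'): boolean {
  res.setHeader('Access-Control-Allow-Origin', '*')
  res.setHeader('Access-Control-Allow-Methods', methods)
  res.setHeader('Access-Control-Allow-Headers', 'Content-Type')

  if (req.method === 'OPTIONS') {
    res.status(200).end()
    return true
  }
  return false
}

// Usage at the top of each handler:
// if (applyCors(req, res)) return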

Comment on lines +22 to +26
console.info('[Memory Debug Search] 收到搜索请求:', {
query: body?.query,
userId: body?.userId,
limit: body?.limit,
})
Member

Debug console? Should be removed?

"name": "@proj-airi/api",
"version": "0.7.2-beta.3",
"private": true,
"description": "Vercel serverless API functions for AIRI",
Member

So it works only on Vercel?

Author

Yes, it's Vercel-oriented. Old habit from my serverless days. Now that we're using xsai, local deployment should be easier to support. Will work on it in a follow-up PR.

Comment on lines +160 to +185
private async ensureTable(): Promise<void> {
await this.pool.query(
`CREATE TABLE IF NOT EXISTS ${this.tableName} (
id BIGSERIAL PRIMARY KEY,
user_id TEXT NOT NULL,
session_id TEXT,
role TEXT,
content TEXT NOT NULL,
metadata JSONB,
embedding vector(${this.embeddingDimensions}) NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);`,
)
}

private async ensureIndexes(): Promise<void> {
await this.pool.query(`CREATE INDEX IF NOT EXISTS ${this.tableName}_user_id_idx ON ${this.tableName} (user_id);`)
await this.pool.query(`CREATE INDEX IF NOT EXISTS ${this.tableName}_session_id_idx ON ${this.tableName} (session_id);`)
await this.pool.query(`CREATE INDEX IF NOT EXISTS ${this.tableName}_created_at_idx ON ${this.tableName} (created_at DESC);`)
await this.pool.query(
`CREATE INDEX IF NOT EXISTS ${this.tableName}_embedding_idx
ON ${this.tableName}
USING ivfflat (embedding vector_l2_ops)
WITH (lists = 100);`,
)
}
Member

Why do we need to take care of this ourselves? Shouldn't Drizzle handle it for us?

Author

We're not using Drizzle elsewhere in the project yet. The current focus is getting full functionality working in serverless first; I'll consider adding Drizzle later for better maintainability.

Comment on lines +205 to +238
private async generateEmbedding(input: string): Promise<number[]> {
  const trimmed = input.trim()
  if (trimmed.length === 0) {
    return Array.from({ length: this.vectorSize }, () => 0)
  }

  let embedding: number[] = []

  if (this.embeddingConfig.provider === 'cloudflare') {
    embedding = await this.generateCloudflareEmbedding(trimmed)
  }
  else {
    if (!this.openai) {
      throw new Error('OpenAI-compatible client is not configured for embeddings.')
    }

    const response = await this.openai.embeddings.create({
      model: this.embeddingConfig.model ?? DEFAULT_EMBEDDING_MODEL,
      input: trimmed,
    })

    embedding = response.data?.[0]?.embedding ?? []
  }

  if (!embedding.length) {
    throw new Error('Failed to generate embedding for memory content.')
  }

  if (embedding.length !== this.vectorSize) {
    throw new Error(`Embedding dimension mismatch. Expected ${this.vectorSize}, received ${embedding.length}.`)
  }

  return embedding.map((value: number | string) => Number(value))
}
Member

Author

Switched.

Comment on lines +258 to +291
private resolveEmbeddingConfiguration(
  config?: EmbeddingProviderConfiguration,
): EmbeddingProviderConfiguration {
  const provider = config?.provider
    ?? (env.MEMORY_EMBEDDING_PROVIDER as EmbeddingProviderConfiguration['provider'] | undefined)
    ?? 'openai'

  const apiKey = config?.apiKey
    ?? env.MEMORY_EMBEDDING_API_KEY
    ?? env.OPENAI_API_KEY

  if (!apiKey) {
    throw new Error('An embedding API key is required.')
  }

  const model = config?.model
    ?? env.MEMORY_EMBEDDING_MODEL
    ?? DEFAULT_EMBEDDING_MODEL

  const baseUrl = config?.baseUrl ?? env.MEMORY_EMBEDDING_BASE_URL
  const accountId = config?.accountId ?? env.CLOUDFLARE_ACCOUNT_ID

  if (provider === 'cloudflare' && !accountId) {
    throw new Error('Cloudflare embedding provider requires an account ID.')
  }

  return {
    provider,
    apiKey,
    model,
    baseUrl,
    accountId,
  } satisfies EmbeddingProviderConfiguration
}
Member

This should be done before application start?

Author

For serverless, a runtime check feels safer since cold starts can happen at any time. But yeah, maybe we could add a separate setup script for traditional deployments.

Comment on lines 293 to 323
private async generateCloudflareEmbedding(input: string): Promise<number[]> {
  const accountId = this.embeddingConfig.accountId
  if (!accountId) {
    throw new Error('Cloudflare account ID is not configured.')
  }

  const url = `${this.embeddingConfig.baseUrl ?? 'https://api.cloudflare.com/client/v4'}/accounts/${accountId}/ai/run/${this.embeddingConfig.model ?? DEFAULT_EMBEDDING_MODEL}`

  const response = await fetch(url, {
    method: 'POST',
    headers: {
      'Authorization': `Bearer ${this.embeddingConfig.apiKey}`,
      'Content-Type': 'application/json',
    },
    body: JSON.stringify({ text: input }),
  })

  if (!response.ok) {
    const body = await response.text()
    throw new Error(`Cloudflare embedding request failed: ${response.status} ${body}`)
  }

  const payload = await response.json() as Record<string, any>
  const embedding = this.extractCloudflareEmbedding(payload)

  if (!embedding || !Array.isArray(embedding) || embedding.length === 0) {
    throw new Error('Cloudflare embedding response did not include an embedding vector.')
  }

  return embedding.map((value: number | string) => Number(value))
}
Member

https://xsai.js.org/docs/packages-ext/providers has Cloudflare supported well enough.

Author

Updated to use xsai now. The Cloudflare implementation now uses xsai's embed function with Cloudflare Workers AI's OpenAI-compatible API endpoint (/accounts/{accountId}/ai/v1), instead of manual fetch calls.

Refactor Cloudflare embedding implementation to use xsai's embed
function with Cloudflare's OpenAI-compatible API endpoint instead
of manual fetch calls.

Changes:
- Use OpenAI-compatible endpoint: /accounts/{id}/ai/v1
- Replace manual fetch with xsai embed function
- Remove extractCloudflareEmbedding helper functions
- Add DEFAULT_CLOUDFLARE_EMBEDDING_MODEL constant
- Default model changed to @cf/baai/bge-base-en-v1.5

This addresses the PR review feedback suggesting to use xsai's
providers for Cloudflare support.

Ref: https://developers.cloudflare.com/workers-ai/configuration/open-ai-compatibility
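
For readers unfamiliar with xsai, the described call shape would look roughly like this (a sketch assuming xsai's embed function from @xsai/embed; the account ID, token, and input are placeholders):

import { embed } from '@xsai/embed'

const accountId = 'your-account-id'

// Cloudflare Workers AI exposes an OpenAI-compatible surface under
// /accounts/{accountId}/ai/v1, so xsai's generic embed call can target it.
const { embedding } = await embed({
  apiKey: 'your-cloudflare-api-token',
  baseURL: `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/v1/`,
  model: '@cf/baai/bge-base-en-v1.5',
  input: 'Hello, AIRI!',
})

console.log(embedding.length) // 768 for bge-base-en-v1.5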